{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Unsupervised learning with Autoencoder\n",
    "\n",
    "Some pieces of code are taken from https://github.com/kevinzakka/vae-pytorch\n",
    "\n",
    "<img src='Autoencoder_structure.png'>\n",
    "\n",
    "Description given by [Wikipedia](https://en.wikipedia.org/wiki/Autoencoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Where your MNIST dataset is stored. Set the MNIST_DATA_DIR environment\n",
    "# variable to point somewhere else without editing this cell.\n",
    "data_dir = os.environ.get('MNIST_DATA_DIR', '/home/lelarge/data')\n",
    "\n",
    "# Training split: shuffled mini-batches of 256 images.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST(data_dir, train=True, download=True, transform=transforms.ToTensor()),\n",
    "    batch_size=256, shuffle=True)\n",
    "\n",
    "# Test split: mini-batches of 10 images in a fixed order.\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST(data_dir, train=False, download=True, transform=transforms.ToTensor()),\n",
    "    batch_size=10, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def to_img(x):\n",
    "    \"\"\"Convert a tensor of digits into a numpy array of 28x28 images.\n",
    "\n",
    "    The values are mapped with 0.5 * (x + 1) (which sends [-1, 1] onto\n",
    "    [0, 1]), clipped to [0, 1] and reshaped to (-1, 28, 28).\n",
    "\n",
    "    Args:\n",
    "        x: torch tensor whose number of elements is a multiple of 28 * 28.\n",
    "\n",
    "    Returns:\n",
    "        numpy array of shape (N, 28, 28) with values in [0, 1].\n",
    "    \"\"\"\n",
    "    # detach().cpu() replaces the legacy `.data` access and also works for\n",
    "    # tensors that require grad or live on a GPU.\n",
    "    x = x.detach().cpu().numpy()\n",
    "    x = 0.5 * (x + 1)\n",
    "    x = np.clip(x, 0, 1)\n",
    "    return x.reshape([-1, 28, 28])\n",
    "\n",
    "def plot_reconstructions(model, conv=False, data_loader=None):\n",
    "    \"\"\"\n",
    "    Plot up to 10 reconstructions from the test set. The top row is the\n",
    "    original digits, the bottom is the reconstruction returned by the model.\n",
    "    The middle row is the encoded vector, drawn as an image with 4 columns.\n",
    "\n",
    "    Args:\n",
    "        model: module exposing an `encoder` sub-module and whose forward\n",
    "            pass returns the reconstruction.\n",
    "        conv: if False the images are flattened to 784 values before being\n",
    "            fed to the model; if True they keep the shape given by the loader.\n",
    "        data_loader: loader whose first batch is plotted; defaults to the\n",
    "            notebook-level `test_loader`.\n",
    "    \"\"\"\n",
    "    if data_loader is None:\n",
    "        data_loader = test_loader\n",
    "    data, _ = next(iter(data_loader))\n",
    "    if not conv:\n",
    "        data = data.view([-1, 784])\n",
    "\n",
    "    # Use the full forward pass for the reconstruction: a model may apply an\n",
    "    # activation between encoder and decoder inside forward(), in which case\n",
    "    # decoder(encoder(x)) differs from model(x).\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            encoded_imgs = model.encoder(data)\n",
    "            decoded_imgs = model(data)\n",
    "    finally:\n",
    "        # leave the model in the mode it was in before plotting\n",
    "        model.train(was_training)\n",
    "\n",
    "    true_imgs = to_img(data)\n",
    "    decoded_imgs = to_img(decoded_imgs)\n",
    "    encoded_imgs = encoded_imgs.cpu().numpy()\n",
    "\n",
    "    # never index past the end of the batch\n",
    "    n = min(10, len(true_imgs))\n",
    "    fig, axes = plt.subplots(3, n, figsize=(20, 4), squeeze=False)\n",
    "    for i in range(n):\n",
    "        # rows: original, encoded vector, reconstruction\n",
    "        column = (true_imgs[i], encoded_imgs[i].reshape(-1, 4), decoded_imgs[i])\n",
    "        for ax, img in zip(axes[:, i], column):\n",
    "            # explicit colormap instead of plt.gray(), which would change\n",
    "            # the default colormap for every later plot\n",
    "            ax.imshow(img, cmap='gray')\n",
    "            ax.get_xaxis().set_visible(False)\n",
    "            ax.get_yaxis().set_visible(False)\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple Auto-Encoder\n",
    "\n",
    "We'll start with the simplest autoencoder: a single, fully-connected layer as the encoder and decoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class AutoEncoder(nn.Module):\n",
    "    \"\"\"Autoencoder with a single fully-connected layer on each side.\n",
    "\n",
    "    The encoder is a linear layer followed by a ReLU and the decoder is a\n",
    "    linear layer, so `model.decoder(model.encoder(x))` gives the same result\n",
    "    as `model(x)`.\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of each flattened input vector.\n",
    "        encoding_dim: size of the code produced by the encoder.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, encoding_dim):\n",
    "        super(AutoEncoder, self).__init__()\n",
    "        # The non-linearity is part of the encoder (as in DeepAutoEncoder) so\n",
    "        # that code calling model.encoder / model.decoder directly sees the\n",
    "        # same activations as forward().\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Linear(input_dim, encoding_dim),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        self.decoder = nn.Linear(encoding_dim, input_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Encode `x` and return its reconstruction.\"\"\"\n",
    "        encoded = self.encoder(x)\n",
    "        decoded = self.decoder(encoded)\n",
    "        return decoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sizes of the input vector and of the code.\n",
    "input_dim, encoding_dim = 784, 32\n",
    "\n",
    "# Model, optimizer (default Adam settings) and reconstruction loss.\n",
    "model = AutoEncoder(input_dim=input_dim, encoding_dim=encoding_dim)\n",
    "optimizer = optim.Adam(params=model.parameters())\n",
    "loss_fn = nn.MSELoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why did we take 784 as input dimension? What is the learning rate?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train_model(model,loss_fn,data_loader=None,epochs=1,optimizer=None):\n",
    "    \"\"\"Train `model` to reconstruct the flattened images of `data_loader`.\n",
    "\n",
    "    Args:\n",
    "        model: module called on batches reshaped to (-1, 784).\n",
    "        loss_fn: callable taking (output, flattened input) and returning a\n",
    "            scalar loss tensor.\n",
    "        data_loader: loader yielding (data, label) batches; the labels are\n",
    "            ignored. Must expose `.dataset` and `len()` for the progress log.\n",
    "        epochs: number of passes over `data_loader`.\n",
    "        optimizer: optimizer updating the parameters of `model`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `data_loader` or `optimizer` is not provided.\n",
    "    \"\"\"\n",
    "    if data_loader is None or optimizer is None:\n",
    "        raise ValueError('train_model requires both data_loader and optimizer')\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        # iterate over the loader that was passed in, not the notebook-level\n",
    "        # train_loader\n",
    "        for batch_idx, (data, _) in enumerate(data_loader):\n",
    "            data = data.view([-1, 784])\n",
    "            optimizer.zero_grad()\n",
    "            output = model(data)\n",
    "            # the input itself is the reconstruction target\n",
    "            loss = loss_fn(output, data)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if batch_idx % 50 == 0:\n",
    "                print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                    epoch, batch_idx * len(data), len(data_loader.dataset),\n",
    "                    100. * batch_idx / len(data_loader), loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train the current model on the training set for 15 epochs.\n",
    "train_model(\n",
    "    model,\n",
    "    loss_fn,\n",
    "    data_loader=train_loader,\n",
    "    epochs=15,\n",
    "    optimizer=optimizer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fully-connected model: the images are flattened before encoding.\n",
    "plot_reconstructions(model, conv=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you remove the non-linearity, what are you doing?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stacked Auto-Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class DeepAutoEncoder(nn.Module):\n",
    "    \"\"\"Fully-connected autoencoder with three linear layers on each side.\n",
    "\n",
    "    Encoder: input_dim -> 128 -> 64 -> encoding_dim, every linear layer\n",
    "    followed by an in-place ReLU.\n",
    "    Decoder: encoding_dim -> 64 -> 128 -> input_dim, with no activation\n",
    "    after the last linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, encoding_dim):\n",
    "        super().__init__()\n",
    "        encoder_layers = [\n",
    "            nn.Linear(input_dim, 128), nn.ReLU(inplace=True),\n",
    "            nn.Linear(128, 64), nn.ReLU(inplace=True),\n",
    "            nn.Linear(64, encoding_dim), nn.ReLU(inplace=True),\n",
    "        ]\n",
    "        decoder_layers = [\n",
    "            nn.Linear(encoding_dim, 64), nn.ReLU(inplace=True),\n",
    "            nn.Linear(64, 128), nn.ReLU(inplace=True),\n",
    "            nn.Linear(128, input_dim),\n",
    "        ]\n",
    "        self.encoder = nn.Sequential(*encoder_layers)\n",
    "        self.decoder = nn.Sequential(*decoder_layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the reconstruction decoder(encoder(x)).\"\"\"\n",
    "        return self.decoder(self.encoder(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Same input and code sizes as for the single-layer model.\n",
    "input_dim, encoding_dim = 784, 32\n",
    "\n",
    "# Fresh model, optimizer (default Adam settings) and reconstruction loss.\n",
    "model = DeepAutoEncoder(input_dim=input_dim, encoding_dim=encoding_dim)\n",
    "optimizer = optim.Adam(params=model.parameters())\n",
    "loss_fn = nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the layers that make up the encoder.\n",
    "model.encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the layers that make up the decoder.\n",
    "model.decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train the stacked autoencoder on the training set for 15 epochs.\n",
    "train_model(\n",
    "    model,\n",
    "    loss_fn,\n",
    "    data_loader=train_loader,\n",
    "    epochs=15,\n",
    "    optimizer=optimizer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fully-connected model: the images are flattened before encoding.\n",
    "plot_reconstructions(model, conv=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exercise\n",
    "\n",
    "- Change the loss to a BCE loss. \n",
    "\n",
    "- Implement weight sharing.\n",
    "\n",
    "Hint: a quick Google search gives:\n",
    "\n",
    "https://discuss.pytorch.org/t/how-to-create-and-train-a-tied-autoencoder/2585"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convolutional Auto-Encoder\n",
    "\n",
    "Deconvolutions (transposed convolutions) can create checkerboard artifacts; see [Odena et al.](https://distill.pub/2016/deconv-checkerboard/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class ConvolutionalAutoEncoder(nn.Module):\n",
    "    \"\"\"Convolutional autoencoder.\n",
    "\n",
    "    The shape comments below hold for inputs of shape (b, 1, 28, 28),\n",
    "    where b is the batch size.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        downsampling = [\n",
    "            nn.Conv2d(1, 16, kernel_size=3, stride=3, padding=1),  # -> b, 16, 10, 10\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),  # -> b, 16, 5, 5\n",
    "            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1),  # -> b, 8, 3, 3\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=1),  # -> b, 8, 2, 2\n",
    "        ]\n",
    "        upsampling = [\n",
    "            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2),  # -> b, 16, 5, 5\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.ConvTranspose2d(16, 8, kernel_size=5, stride=3, padding=1),  # -> b, 8, 15, 15\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2, padding=1),  # -> b, 1, 28, 28\n",
    "        ]\n",
    "        self.encoder = nn.Sequential(*downsampling)\n",
    "        self.decoder = nn.Sequential(*upsampling)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the reconstruction decoder(encoder(x)).\"\"\"\n",
    "        return self.decoder(self.encoder(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model = ConvolutionalAutoEncoder()\n",
    "optimizer = optim.Adam(params=model.parameters())\n",
    "# The decoder ends without an activation, so its raw outputs are handed to\n",
    "# the loss as logits.\n",
    "loss_fn = nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why is \n",
    "\n",
    "`train_model(model,loss_fn,data_loader=train_loader,epochs=15,optimizer=optimizer)` \n",
    "\n",
    "not working? Make the necessary modification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train_convmodel(model,loss_fn,data_loader=None,epochs=1,optimizer=None):\n",
    "    \"\"\"Training loop for the convolutional autoencoder (exercise skeleton).\n",
    "\n",
    "    The body of the inner loop is left to the reader; it has to define\n",
    "    `loss`, which the progress message below reads.\n",
    "\n",
    "    Args:\n",
    "        model: module to train.\n",
    "        loss_fn: loss to use in the code you add.\n",
    "        data_loader: loader yielding (data, label) batches; the labels are\n",
    "            ignored. Must expose `.dataset` and `len()` for the progress log.\n",
    "        epochs: number of passes over `data_loader`.\n",
    "        optimizer: optimizer to use in the code you add.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `data_loader` is not provided.\n",
    "    \"\"\"\n",
    "    if data_loader is None:\n",
    "        raise ValueError('train_convmodel requires a data_loader')\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        # iterate over the loader that was passed in, not the notebook-level\n",
    "        # train_loader\n",
    "        for batch_idx, (data, _) in enumerate(data_loader):\n",
    "            #\n",
    "            # your code here\n",
    "            #\n",
    "            if batch_idx % 50 == 0:\n",
    "                print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                    epoch, batch_idx * len(data), len(data_loader.dataset),\n",
    "                    100. * batch_idx / len(data_loader), loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the convolutional autoencoder on the training set for 15 epochs.\n",
    "train_convmodel(\n",
    "    model,\n",
    "    loss_fn,\n",
    "    data_loader=train_loader,\n",
    "    epochs=15,\n",
    "    optimizer=optimizer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convolutional model: keep the image shape instead of flattening.\n",
    "plot_reconstructions(model=model, conv=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exercise\n",
    "\n",
    "Implement a denoising AE:\n",
    "\n",
    "![denoising AE](denoisingAE.png)\n",
    "\n",
    "\n",
    "Using the previous code with minimal modifications, transform your AE into a denoising AE.\n",
    "\n",
    "Why is this not working well? Try to improve the denoising performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should obtain results like this:\n",
    "\n",
    "![res_denoise](denoiseAE.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
